import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import os

class Transpose(nn.Module):
    """Swap two tensor dimensions so ``transpose`` can be used inside ``nn.Sequential``.

    Args:
        shape: the dimension indices passed (unpacked) to ``torch.transpose``,
            e.g. ``(1, 2)``.
    """

    def __init__(self, shape: tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # Equivalent to x.transpose(*self.shape).
        return torch.transpose(x, *self.shape)

class Conv_MLP(nn.Module):
    """Channels-last projection built from a kernel-size-3 ``Conv1d``.

    The input is transposed on dims 1 and 2 before the convolution and back
    again afterwards, so a 3-D input laid out as ``[batch, seq_len, in_dim]``
    comes out as ``[batch, seq_len, out_dim]``. ``padding=1`` with
    ``kernel_size=3`` keeps the sequence length unchanged.

    Args:
        in_dim: input channel count seen by the convolution.
        out_dim: output channel count.
        resid_pdrop: dropout probability applied after the convolution.
    """

    def __init__(self, in_dim, out_dim, resid_pdrop=0.):
        super().__init__()
        # Kept as one nn.Sequential (indices 0/1/2) so state_dict keys such as
        # "sequential.1.weight" stay the same.
        to_channels_first = Transpose(shape=(1, 2))
        conv = nn.Conv1d(in_dim, out_dim, kernel_size=3, stride=1, padding=1)
        drop = nn.Dropout(p=resid_pdrop)
        self.sequential = nn.Sequential(to_channels_first, conv, drop)

    def forward(self, x):
        channels_first = self.sequential(x)
        # Restore the channels-last layout.
        return channels_first.transpose(1, 2)
    
class LearnablePositionalEncoding(nn.Module):
    """Add a learned per-position embedding to a batch-first sequence.

    Args:
        d_model: embedding size; must match the last dim of the input.
        dropout: dropout probability applied after the addition.
        max_len: number of learnable positions, i.e. the longest sequence
            that can be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=2048):
        super(LearnablePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # One embedding per position. Positions are always 0..seq_len-1, so
        # forward slices the first seq_len rows instead of doing a lookup.
        self.pe = nn.Parameter(torch.empty(1, max_len, d_model))  # requires_grad automatically set to True
        nn.init.uniform_(self.pe, -0.02, 0.02)

    def forward(self, x):
        r"""Add positional embeddings to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]
        Raises:
            ValueError: if the sequence length exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional encoding"
            )
        # Slice so inputs shorter than max_len broadcast correctly. Adding the
        # full [1, max_len, d_model] table only worked when seq_len == max_len.
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)
    
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalars (e.g. diffusion timesteps).

    For an input of shape ``[batch]`` the output has shape ``[batch, dim]``.
    The first ``dim // 2`` features are sines and the next ``dim // 2`` are
    cosines, over geometrically spaced frequencies from 1 down to 1/10000.
    For odd ``dim``, a trailing zero column is appended so the width always
    equals ``dim``.

    Args:
        dim: output embedding width.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Guard the denominator: dim in {2, 3} gives half_dim - 1 == 0, which
        # used to raise ZeroDivisionError. Results for half_dim >= 2 are unchanged.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        args = x[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Without padding an odd dim would yield dim - 1 features and
            # mismatch consumers sized to dim (e.g. nn.Linear(dim, ...)).
            emb = F.pad(emb, (0, 1))
        return emb
    
class AdaLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a timestep embedding.

    The timestep is embedded with ``SinusoidalPosEmb``. An optional
    ``label_emb`` is added to that embedding. The result goes through
    SiLU -> Linear(n_embd, 2 * n_embd) and is split into ``scale`` and
    ``shift``. The output is ``LayerNorm(x) * (1 + scale) + shift``, where
    the LayerNorm has no learnable affine parameters of its own.

    Args:
        n_embd: feature size of ``x`` and of the timestep embedding.
    """

    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)

    def forward(self, x, timestep, label_emb=None):
        cond = self.emb(timestep)
        if label_emb is not None:
            cond = cond + label_emb
        # Insert a length-1 axis at dim 1 so scale/shift broadcast over dim 1 of x
        # (assumed to be the sequence axis of a [batch, seq, n_embd] input).
        modulation = self.linear(self.silu(cond))[:, None]
        scale, shift = modulation.chunk(2, dim=2)
        normed = self.layernorm(x)
        return normed * (1 + scale) + shift

class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """``nn.TransformerEncoderLayer`` that records the attention weights of its last forward pass.

    After each forward call, ``last_attn_weights`` holds the weights returned
    by ``self_attn``. Per-head weights are requested
    (``average_attn_weights=False``), so with ``batch_first=True`` the shape is
    ``[batch, num_heads, seq_len, seq_len]``. The stored tensor is not
    detached.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_attn_weights = None

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Same residual structure as the base layer (pre-norm or post-norm),
        # but routed through the overridden _sa_block so weights are recorded.
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        return x

    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Self-attention sub-block that also stores the per-head attention weights."""
        x, attn_weights = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            # Without this, MultiheadAttention averages over heads and returns
            # [batch, seq, seq]. Callers index these weights per head.
            average_attn_weights=False,
        )
        self.last_attn_weights = attn_weights
        return self.dropout1(x)

class CustomTransformerEncoder(nn.TransformerEncoder):
    """``nn.TransformerEncoder`` that collects each layer's attention weights.

    After each forward call, ``attention_weights`` is a list with one entry per
    layer: that layer's ``last_attn_weights``, or ``None`` if the layer does
    not expose the attribute. Before the first forward call the list is empty.

    Args:
        encoder_layer: template layer passed to ``nn.TransformerEncoder``.
        num_layers: number of stacked layers.
        norm: optional module applied to the final output (same meaning as
            in ``nn.TransformerEncoder``). Defaults to ``None``.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__(encoder_layer, num_layers, norm=norm)
        # Initialized here so reading it before the first forward pass does
        # not raise AttributeError.
        self.attention_weights = []

    def forward(self, src, mask=None, src_key_padding_mask=None):
        output = src
        collected = []

        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
            collected.append(getattr(layer, "last_attn_weights", None))

        self.attention_weights = collected

        if self.norm is not None:
            output = self.norm(output)
        return output

class DiffusionModel(nn.Module):
    """Transformer-encoder network over sequences, conditioned on a timestep.

    Forward pipeline: ``Conv_MLP`` input embedding -> learnable positional
    encoding -> timestep-conditioned ``AdaLayerNorm`` -> transformer encoder
    -> ``Conv_MLP`` projection to ``original_input_dim``.

    Args:
        input_dim: feature size of the input sequence.
        original_input_dim: feature size of the output sequence.
        d_model: hidden size used throughout the encoder.
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        hidden_dim: feed-forward width inside each encoder layer.
        max_seq_len: number of learnable positional embeddings.
    """

    def __init__(self, input_dim, original_input_dim, d_model, num_heads, num_layers, hidden_dim, max_seq_len=2048):
        super(DiffusionModel, self).__init__()
        self.d_model = d_model
        self.input_dim = input_dim
        self.original_input_dim = original_input_dim

        # Input embedding and output projection
        self.emb = Conv_MLP(input_dim, d_model, resid_pdrop=0)
        self.inverse = Conv_MLP(d_model, original_input_dim, resid_pdrop=0)

        # Positional encoding
        self.pos_enc = LearnablePositionalEncoding(d_model, dropout=0, max_len=max_seq_len)

        # Adaptive Layer Normalization
        self.ln1 = AdaLayerNorm(d_model)

        # NOTE: not used in forward(). It holds no parameters and is kept only
        # so existing attribute references keep working.
        self.dropout = nn.Dropout(p=0.1)

        # Template layer for the encoder below.
        # NOTE(review): nn.TransformerEncoder builds its layers as copies of
        # this template, so the parameters registered under `encoder_layer`
        # are never used in forward(). They still appear in state_dict() and
        # parameters(), which may matter for optimizers and DDP unused-parameter
        # detection. They are kept to preserve checkpoint keys.
        self.encoder_layer = CustomTransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            batch_first=True,
            activation='gelu',
            dropout=0.1
        )

        # Create custom transformer encoder
        self.transformer_encoder = CustomTransformerEncoder(
            self.encoder_layer,
            num_layers=num_layers
        )

        # Attention weights from the most recent forward pass
        self.attention_weights = []

    def forward(self, x, t, label_emb=None):
        """
        Forward pass of the model.
        Args:
            x: Input tensor of shape [batch_size, seq_len, input_dim]
            t: Timestep tensor of shape [batch_size]
            label_emb: Optional conditioning tensor added to the timestep
                embedding inside ``AdaLayerNorm``. It must broadcast against
                the sinusoidal timestep embedding (presumably
                [batch_size, d_model]; verify against callers). Defaults to None.
        Returns:
            Output tensor of shape [batch_size, seq_len, original_input_dim]
        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape [batch_size, seq_len, input_dim], got {tuple(x.shape)}"
            )

        # Input embedding, then positional encoding
        x = self.emb(x)
        x = self.pos_enc(x)

        # Timestep (and optional label) conditioned normalization
        x = self.ln1(x, t, label_emb)

        # Pass through transformer encoder and keep its per-layer attention
        x = self.transformer_encoder(x)
        self.attention_weights = self.transformer_encoder.attention_weights

        # Project back to original dimension
        return self.inverse(x)

    def get_attention_weights(self):
        """
        Returns the attention weights from the last forward pass.
        Returns:
            List with one entry per encoder layer: the tensor that layer
            recorded as ``last_attn_weights``. The shape is
            [batch_size, num_heads, seq_len, seq_len] when the layer requests
            per-head weights (``average_attn_weights=False``), and the
            head-averaged [batch_size, seq_len, seq_len] otherwise.
            The list is empty before the first forward pass.
        """
        return self.attention_weights

def save_attention_weights(attention_weights, save_dir, iteration):
    """
    Write each layer's attention weights to its own ``.pt`` file.

    Files are named ``attn_weights_iter{iteration}_layer{i}.pt`` inside
    ``save_dir``, which is created if it does not exist. Each entry is saved
    as-is with ``torch.save``.

    Args:
        attention_weights: Sequence of per-layer attention weight tensors
        save_dir: Directory to save the weights
        iteration: Iteration number embedded in each file name
    """
    os.makedirs(save_dir, exist_ok=True)

    targets = (
        (os.path.join(save_dir, f'attn_weights_iter{iteration}_layer{idx}.pt'), weights)
        for idx, weights in enumerate(attention_weights)
    )
    for path, weights in targets:
        torch.save(weights, path)

def visualize_attention_weights(attention_weights, layer_idx, head_idx=0, batch_idx=0,
                                save_dir=None, close=True):
    """
    Visualize attention weights as a heatmap and save it as a PNG.

    Args:
        attention_weights: Sequence of per-layer attention weight tensors
            (e.g. the list returned by ``DiffusionModel.get_attention_weights()``).
        layer_idx: Index of the layer to visualize
        head_idx: Index of the attention head to visualize. It is ignored when
            the layer's weights are head-averaged ([batch, seq, seq]).
        batch_idx: Index of the batch item to visualize
        save_dir: Optional directory for the PNG, created if missing. Defaults
            to None, which saves to the current working directory as before.
        close: Close the figure after saving (default True) so repeated calls
            do not accumulate open matplotlib figures.
    Returns:
        The path the PNG was written to.
    Raises:
        ValueError: if the selected layer has no recorded weights (``None``).
    """
    layer_weights = attention_weights[layer_idx]
    if layer_weights is None:
        raise ValueError(f"no attention weights recorded for layer {layer_idx}")

    selected = layer_weights[batch_idx]
    # Per-head weights are [batch, heads, q, k], so one more index picks the
    # head. Head-averaged weights are [batch, q, k] and are already 2-D here.
    if selected.dim() == 3:
        selected = selected[head_idx]
    # detach(): .numpy() raises on tensors that require grad, such as weights
    # captured during a training forward pass.
    weights = selected.detach().cpu().numpy()

    filename = f'attn_weights_layer{layer_idx}_head{head_idx}_batch{batch_idx}.png'
    if save_dir is None:
        save_path = filename
    else:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, filename)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(weights, cmap='viridis', ax=ax)
        ax.set_title(f'Attention Weights - Layer {layer_idx}, Head {head_idx}')
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')
        fig.savefig(save_path)
    finally:
        if close:
            plt.close(fig)
    return save_path